# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn

import datasets
from macro import GeneralNetwork
from micro import MicroNetwork
from nni.algorithms.nas.pytorch import enas
from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
                                       LRSchedulerCallback)
from utils import accuracy, reward_accuracy

logger = logging.getLogger('nni')


if __name__ == "__main__":
    parser = ArgumentParser("enas")
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=10, type=int)
    parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
    parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
    parser.add_argument("--visualization", default=False, action="store_true")
    parser.add_argument("--v1", default=False, action="store_true")
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")
    mutator = None
    ctrl_kwargs = {}
    if args.search_for == "macro":
        model = GeneralNetwork()
        num_epochs = args.epochs or 310
    elif args.search_for == "micro":
        model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=False)
        num_epochs = args.epochs or 150
        if args.v1:
            mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
        else:
            ctrl_kwargs = {"tanh_constant": 1.1}
    else:
        raise AssertionError

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)

    if args.v1:
        trainer = enas.EnasTrainer(model,
                                   loss=criterion,
                                   metrics=accuracy,
                                   reward_function=reward_accuracy,
                                   optimizer=optimizer,
                                   callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
                                   batch_size=args.batch_size,
                                   num_epochs=num_epochs,
                                   dataset_train=dataset_train,
                                   dataset_valid=dataset_valid,
                                   log_frequency=args.log_frequency,
                                   mutator=mutator)
        if args.visualization:
            trainer.enable_visualization()
        trainer.train()
    else:
        from nni.retiarii.trainer.pytorch.enas import EnasTrainer
        trainer = EnasTrainer(model,
                              loss=criterion,
                              metrics=accuracy,
                              reward_function=reward_accuracy,
                              optimizer=optimizer,
                              batch_size=args.batch_size,
                              num_epochs=num_epochs,
                              dataset=dataset_train,
                              log_frequency=args.log_frequency,
                              ctrl_kwargs=ctrl_kwargs)
        trainer.fit()
